<!DOCTYPE HTML>
<html lang="zh-CN">


<head>
    <meta charset="utf-8">
    <meta name="keywords" content="深度学习——图像分类入门PartⅠ, J Sir">
    <meta name="description" content="">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <meta name="renderer" content="webkit|ie-stand|ie-comp">
    <meta name="mobile-web-app-capable" content="yes">
    <meta name="format-detection" content="telephone=no">
    <meta name="apple-mobile-web-app-capable" content="yes">
    <meta name="apple-mobile-web-app-status-bar-style" content="black-translucent">
    <!-- Global site tag (gtag.js) - Google Analytics -->


    <title>深度学习——图像分类入门PartⅠ | J Sir</title>
    <link rel="icon" type="image/png" href="/favicon.png">

    <link rel="stylesheet" href="/libs/awesome/css/all.css">
    <link rel="stylesheet" href="/libs/materialize/materialize.min.css">
    <link rel="stylesheet" href="/libs/aos/aos.css">
    <link rel="stylesheet" href="/libs/animate/animate.min.css">
    <link rel="stylesheet" href="/libs/lightGallery/css/lightgallery.min.css">
    <link rel="stylesheet" href="/css/matery.css">
    <link rel="stylesheet" href="/css/my.css">

    <script src="/libs/jquery/jquery.min.js"></script>

<meta name="generator" content="Hexo 6.0.0">
<style>.github-emoji { position: relative; display: inline-block; width: 1.2em; min-height: 1.2em; overflow: hidden; vertical-align: top; color: transparent; }  .github-emoji > span { position: relative; z-index: 10; }  .github-emoji img, .github-emoji .fancybox { margin: 0 !important; padding: 0 !important; border: none !important; outline: none !important; text-decoration: none !important; user-select: none !important; cursor: auto !important; }  .github-emoji img { height: 1.2em !important; width: 1.2em !important; position: absolute !important; left: 50% !important; top: 50% !important; transform: translate(-50%, -50%) !important; user-select: none !important; cursor: auto !important; } .github-emoji-fallback { color: inherit; } .github-emoji-fallback img { opacity: 0 !important; }</style>
</head>




<body>
    <header class="navbar-fixed">
    <nav id="headNav" class="bg-color nav-transparent">
        <div id="navContainer" class="nav-wrapper container">
            <div class="brand-logo">
                <a href="/" class="waves-effect waves-light">
                    
                    <img src="/medias/logo.png" class="logo-img" alt="LOGO">
                    
                    <span class="logo-span">J Sir</span>
                </a>
            </div>
            

<a href="#" data-target="mobile-nav" class="sidenav-trigger button-collapse"><i class="fas fa-bars"></i></a>
<ul class="right nav-menu">
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/" class="waves-effect waves-light">
      
      <i class="fas fa-home" style="zoom: 0.6;"></i>
      
      <span>首页</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/tags" class="waves-effect waves-light">
      
      <i class="fas fa-tags" style="zoom: 0.6;"></i>
      
      <span>标签</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/categories" class="waves-effect waves-light">
      
      <i class="fas fa-bookmark" style="zoom: 0.6;"></i>
      
      <span>分类</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/archives" class="waves-effect waves-light">
      
      <i class="fas fa-archive" style="zoom: 0.6;"></i>
      
      <span>归档</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/about" class="waves-effect waves-light">
      
      <i class="fas fa-user-circle" style="zoom: 0.6;"></i>
      
      <span>关于</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/contact" class="waves-effect waves-light">
      
      <i class="fas fa-comments" style="zoom: 0.6;"></i>
      
      <span>留言板</span>
    </a>
    
  </li>
  
  <li class="hide-on-med-and-down nav-item">
    
    <a href="/friends" class="waves-effect waves-light">
      
      <i class="fas fa-address-book" style="zoom: 0.6;"></i>
      
      <span>友情链接</span>
    </a>
    
  </li>
  
  <li>
    <a href="#searchModal" class="modal-trigger waves-effect waves-light">
      <i id="searchIcon" class="fas fa-search" title="搜索" style="zoom: 0.85;"></i>
    </a>
  </li>
</ul>


<div id="mobile-nav" class="side-nav sidenav">

    <div class="mobile-head bg-color">
        
        <img src="/medias/logo.png" class="logo-img circle responsive-img" alt="J Sir logo">
        
        <div class="logo-name">J Sir</div>
        <div class="logo-desc">
            
            Never really desperate, only the lost of the soul.
            
        </div>
    </div>

    

    <ul class="menu-list mobile-menu-list">
        
        <li class="m-nav-item">
	  
		<a href="/" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-home"></i>
			
			首页
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/tags" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-tags"></i>
			
			标签
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/categories" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-bookmark"></i>
			
			分类
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/archives" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-archive"></i>
			
			归档
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/about" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-user-circle"></i>
			
			关于
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/contact" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-comments"></i>
			
			留言板
		</a>
          
        </li>
        
        <li class="m-nav-item">
	  
		<a href="/friends" class="waves-effect waves-light">
			
			    <i class="fa-fw fas fa-address-book"></i>
			
			友情链接
		</a>
          
        </li>
        
        
        <li><div class="divider"></div></li>
        <li>
            <a href="https://github.com/jy741" class="waves-effect waves-light" target="_blank" rel="noopener noreferrer">
                <i class="fab fa-github-square fa-fw"></i>Fork Me
            </a>
        </li>
        
    </ul>
</div>


        </div>

        
            <style>
    /* Keep the GitHub corner hidden while the navbar is in its
       transparent (top-of-page) state. */
    .nav-transparent .github-corner {
        display: none !important;
    }

    /* Pin the corner ribbon to the top-right of the nav, above normal flow. */
    .github-corner {
        position: absolute;
        z-index: 10;
        top: 0;
        right: 0;
        border: 0;
        transform: scale(1.1);
    }

    /* Octocat colors: paths using fill="currentColor" pick up `color`;
       the triangle background path uses the white `fill`. */
    .github-corner svg {
        color: #0f9d58;
        fill: #fff;
        height: 64px;
        width: 64px;
    }

    /* Wave the octocat's arm on hover (keyframes "a" below). */
    .github-corner:hover .octo-arm {
        animation: a 0.56s ease-in-out;
    }

    .github-corner .octo-arm {
        animation: none;
    }

    /* Arm-waving animation: rock between -25deg and 10deg, ending at 0. */
    @keyframes a {
        0%,
        to {
            transform: rotate(0);
        }
        20%,
        60% {
            transform: rotate(-25deg);
        }
        40%,
        80% {
            transform: rotate(10deg);
        }
    }
</style>

<a href="https://github.com/jy741" class="github-corner tooltipped hide-on-med-and-down" target="_blank"
   rel="noopener noreferrer" aria-label="Fork Me on GitHub"
   data-tooltip="Fork Me" data-position="left" data-delay="50">
    <svg viewBox="0 0 250 250" aria-hidden="true">
        <path d="M0,0 L115,115 L130,115 L142,142 L250,250 L250,0 Z"></path>
        <path d="M128.3,109.0 C113.8,99.7 119.0,89.6 119.0,89.6 C122.0,82.7 120.5,78.6 120.5,78.6 C119.2,72.0 123.4,76.3 123.4,76.3 C127.3,80.9 125.5,87.3 125.5,87.3 C122.9,97.6 130.6,101.9 134.4,103.2"
              fill="currentColor" style="transform-origin: 130px 106px;" class="octo-arm"></path>
        <path d="M115.0,115.0 C114.9,115.1 118.7,116.5 119.8,115.4 L133.7,101.6 C136.9,99.2 139.9,98.4 142.2,98.6 C133.8,88.0 127.5,74.4 143.8,58.0 C148.5,53.4 154.0,51.2 159.7,51.0 C160.3,49.4 163.2,43.6 171.4,40.1 C171.4,40.1 176.1,42.5 178.8,56.2 C183.1,58.6 187.2,61.8 190.9,65.4 C194.5,69.0 197.7,73.2 200.1,77.6 C213.8,80.2 216.3,84.9 216.3,84.9 C212.7,93.1 206.9,96.0 205.4,96.6 C205.1,102.4 203.0,107.8 198.3,112.5 C181.9,128.9 168.3,122.5 157.7,114.1 C157.9,116.9 156.7,120.9 152.7,124.9 L141.0,136.5 C139.8,137.7 141.6,141.9 141.8,141.8 Z"
              fill="currentColor" class="octo-body"></path>
    </svg>
</a>
        
    </nav>

</header>

    



<div class="bg-cover pd-header post-cover" style="background-image: url('/medias/featureimages/11.jpg')">
    <div class="container" style="right: 0px;left: 0px;">
        <div class="row">
            <div class="col s12 m12 l12">
                <div class="brand">
                    <h1 class="description center-align post-title">深度学习——图像分类入门PartⅠ</h1>
                </div>
            </div>
        </div>
    </div>
</div>




<main class="post-container content">

    
    <link rel="stylesheet" href="/libs/tocbot/tocbot.css">
<style>
    /* Invisible spacer pseudo-elements so in-page anchor jumps land ~100px
       below the fixed header instead of underneath it. */
    #articleContent h1::before,
    #articleContent h2::before,
    #articleContent h3::before,
    #articleContent h4::before,
    #articleContent h5::before,
    #articleContent h6::before {
        display: block;
        content: " ";
        height: 100px;
        margin-top: -100px;
        visibility: hidden;
    }

    /* NOTE(review): removes the focus outline inside the article with no
       visible replacement — consider a :focus-visible style instead. */
    #articleContent :focus {
        outline: none;
    }

    /* Applied (presumably by JS) to keep the TOC visible while scrolling. */
    .toc-fixed {
        position: fixed;
        top: 64px;
    }

    /* Sidebar widget that hosts the table of contents. */
    .toc-widget {
        width: 345px;
        padding-left: 20px;
    }

    .toc-widget .toc-title {
        padding: 35px 0 15px 17px;
        font-size: 1.5rem;
        font-weight: bold;
        line-height: 1.5rem;
    }

    .toc-widget ol {
        padding: 0;
        list-style: none;
    }

    /* Scrollable TOC body (tocbot renders nested <ol> lists here). */
    #toc-content {
        padding-bottom: 30px;
        overflow: auto;
    }

    #toc-content ol {
        padding-left: 10px;
    }

    #toc-content ol li {
        padding-left: 10px;
    }

    #toc-content .toc-link:hover {
        color: #42b983;
        font-weight: 700;
        text-decoration: underline;
    }

    /* Neutralize tocbot's default ::before marker while keeping it positioned. */
    #toc-content .toc-link::before {
        background-color: transparent;
        max-height: 25px;

        position: absolute;
        right: 23.5vw;
        display: block;
    }

    /* Highlight color for the section currently in view. */
    #toc-content .is-active-link {
        color: #42b983;
    }

    /* Floating button that toggles the TOC on small screens. */
    #floating-toc-btn {
        position: fixed;
        right: 15px;
        bottom: 76px;
        padding-top: 15px;
        margin-bottom: 0;
        z-index: 998;
    }

    #floating-toc-btn .btn-floating {
        width: 48px;
        height: 48px;
    }

    #floating-toc-btn .btn-floating i {
        line-height: 48px;
        font-size: 1.4rem;
    }
</style>
<div class="row">
    <div id="main-content" class="col s12 m12 l9">
        <!-- 文章内容详情 -->
<div id="artDetail">
    <div class="card">
        <div class="card-content article-info">
            <div class="row tag-cate">
                <div class="col s7">
                    
                    <div class="article-tag">
                        
                            <a href="/tags/python/">
                                <span class="chip bg-color">python</span>
                            </a>
                        
                            <a href="/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/">
                                <span class="chip bg-color">深度学习</span>
                            </a>
                        
                    </div>
                    
                </div>
                <div class="col s5 right-align">
                    
                </div>
            </div>

            <div class="post-info">
                
                <div class="post-date info-break-policy">
                    <i class="far fa-calendar-minus fa-fw"></i>发布日期:&nbsp;&nbsp;
                    2022-10-19
                </div>
                

                

                
                <div class="info-break-policy">
                    <i class="far fa-file-word fa-fw"></i>文章字数:&nbsp;&nbsp;
                    9.8k
                </div>
                

                

                
                    <div id="busuanzi_container_page_pv" class="info-break-policy">
                        <i class="far fa-eye fa-fw"></i>阅读次数:&nbsp;&nbsp;
                        <span id="busuanzi_value_page_pv"></span>
                    </div>
				
            </div>
        </div>
        <hr class="clearfix">

        
        <!-- 是否加载使用自带的 prismjs. -->
        <link rel="stylesheet" href="/libs/prism/prism.css">
        

        

        <div class="card-content article-card-content">
            <div id="articleContent">
                <h1 id="EverydayOneCat"><a href="#EverydayOneCat" class="headerlink" title="EverydayOneCat"></a>EverydayOneCat</h1><p>⬆️??!??!⬆️ 😯😯⁉️ 🐯⬇️</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/96afa2fe24c53c8f6318fff90ea2bd8a14576148.png@1036w.webp" alt="96afa2fe24c53c8f6318fff90ea2bd8a14576148.png@1036w" style="zoom:50%;"></p>
<span id="more"></span>
<h1 id="LeNet"><a href="#LeNet" class="headerlink" title="LeNet"></a>LeNet</h1><h2 id="LeNet介绍"><a href="#LeNet介绍" class="headerlink" title="LeNet介绍"></a>LeNet介绍</h2><p>LeNet可以说是第一个运用到卷积神经网络CNN的模型，它包含了深度学习的基本模块：卷积层，池化层，全连接层。是其他深度学习模型的基础， 这里我们对LeNet进行深入分析。同时，通过实例分析，加深对于卷积层和池化层的理解。</p>
<p>LeNet共有7层，不包含输入，每层都包含可训练参数；每个层有多个Feature Map，每个Feature Map是通过一种卷积滤波器提取输入的一种特征，然后每个Feature Map有多个神经元。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019154253998.png" alt="image-20221019154253998"></p>
<p>或者我们用一张更加直观的图来展现LeNet模型的各个层：</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019212915742.png" alt="image-20221019212915742"></p>
<h2 id="model"><a href="#model" class="headerlink" title="model"></a>model</h2><p>基于上图我们可以轻松写出模型：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.model = nn.Sequential(nn.Conv2d(3,16,5),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2),
                                   nn.Conv2d(16,32,5),
                                   nn.ReLU(),
                                   nn.MaxPool2d(2),
                                   nn.Flatten(),
                                   nn.Linear(32*5*5,120),
                                   nn.ReLU(),
                                   nn.Linear(120,84),
                                   nn.ReLU(),
                                   nn.Linear(84,10))

    def forward(self,x):
        x = self.model(x)
        return x<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="train"><a href="#train" class="headerlink" title="train"></a>train</h2><p>接下来我们通过CIFAR10数据集写一个Demo来测试LeNet的实际训练效果</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019213010506.png" alt="image-20221019213010506"></p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">transforms = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]
)

train_dataset = torchvision.datasets.CIFAR10("../data",train=True,transform=transforms,download=True)
test_dataset = torchvision.datasets.CIFAR10("../data",train=False,transform=transforms,download=True)

trainset_len = len(train_dataset)
testset_len = len(test_dataset)

train_dataloader = DataLoader(train_dataset,50,shuffle=True)
test_dataloader = DataLoader(test_dataset,50,shuffle=True)

#使用cpu还是gpu训练
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using {} device".format(device))

#实例化模型
lenet = LeNet()
lenet = lenet.to(device)

#定义误差函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

#定义优化器，每次更新模型的所有参数
optimizer = torch.optim.Adam(lenet.parameters(),lr=0.001)

#定义训练次数
train_times = 10

for epch in range(train_times):
    #开始训练
    lenet.train()#模型进入训练模式
    train_loss = 0
    for step,data in enumerate(train_dataloader,start=1):
        imgs,targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = lenet(imgs)
        loss = loss_fn(outputs,targets)
        #用优化器更新参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        #每500次打印一下
        if step%500 == 0:
            print("times:{} --&gt; train_loss:{}".format(step,train_loss))


    #开始测试
    lenet.eval()
    test_loss = 0
    test_arracy = 0

    with torch.no_grad():
        for step, data in enumerate(test_dataloader, start=1):
            imgs, targets = data
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = lenet(imgs)
            loss = loss_fn(outputs,targets)
            test_loss += loss.item()

            arracy = (outputs.argmax(1)==targets).sum()
            test_arracy += arracy

        print("test_loss:{}".format(test_loss))
        print("accarcy:{}".format((test_arracy/testset_len)))



print("Finished Training")

save_path = './LeNet.pth'
torch.save(lenet.state_dict(),save_path)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="predict"><a href="#predict" class="headerlink" title="predict"></a>predict</h2><p>有了基于数据集训练好的模型，我们可以加载其参数对未知图片进行预测分类</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32,32)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

img_path = "./frog_1.png"
image = Image.open(img_path)
image = transforms(image)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

net = LeNet()
net.load_state_dict(torch.load('LeNet.pth',map_location=torch.device('cpu')))

image = torch.reshape(image,(1,3,32,32))

with torch.no_grad():
    output = net(image)
    predict = output.argmax(1).item()

print(classes[predict])<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h1 id="AlexNet"><a href="#AlexNet" class="headerlink" title="AlexNet"></a>AlexNet</h1><h2 id="AlexNet介绍"><a href="#AlexNet介绍" class="headerlink" title="AlexNet介绍"></a>AlexNet介绍</h2><p>AlexNet是2012年ILSVRC 2012（ImageNet Large Scale Visual Recognition Challenge）竞赛的冠军网络，分类准确率由传统的 70%+提升到 80%+。它是由Hinton和他的学生Alex Krizhevsky设计的。也是在那年之后，深度学习开始迅速发展。</p>
<p>该网络亮点在于：</p>
<ul>
<li>首次使用GPU作为网络加速训练</li>
<li>使用了ReLU激活函数，而不是传统的Sigmoid激活函数以及Tanh激活函数（Sigmoid存在梯度丢失、微分困难等问题）</li>
<li>使用了LRN局部响应归一化</li>
<li>在全连接层的前两层中使用了 Dropout 随机失活神经元操作，以减少过拟合</li>
</ul>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019213206995.png" alt="image-20221019213206995"></p>
<h2 id="DropOut"><a href="#DropOut" class="headerlink" title="DropOut"></a>DropOut</h2><p>我们知道在CNN中，我们经过一系列Convolution、Maxpooling后，虽然少了很多参数，但是Flatten后丢入全连接层input参数还是很多的，而参数过多就可能会导致过拟合问题。为此我们引入DropOut机制，在网络正向传播过程中随机失活一部分神经元，以减少过拟合。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019213040457.png" alt="image-20221019213040457"></p>
<h2 id="model-1"><a href="#model-1" class="headerlink" title="model"></a>model</h2><p>我们总结一下上图AlexNet的过程，可以总结出如下表格</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019213120892.png" alt="image-20221019213120892"></p>
<p>基于此我们写出model.py</p>
<blockquote>
<p>为了加快训练，代码只使用了一半的网络参数，相当于只用了原论文中网络结构的下半部分（正好原论文中用的双GPU，我的电脑只有一块GPU）（后来我又用完整网络跑了遍，发现一半参数跟完整参数的训练结果acc相差无几）</p>
</blockquote>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">class AlexNet(nn.Module):
    def __init__(self,num_classes=1000,init_weights=False):
        super(AlexNet, self).__init__()
        #卷积层
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),  # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),  # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),  # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output[128, 6, 6]
        )
        #全连接层
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),  #DropOut随机失活神经元
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)   #Tensor的Shape格式为bach、channel、h、w 我们展平不需要管bach这个维度
        x = self.classifier(x)
        return x

    #网络权重初始化，实际上 pytorch 在构建网络时会自动初始化权重
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):#如果是卷积层
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')  #初始化参数w，用的是何凯明初始化法
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):#如果是连接层
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="数据集处理"><a href="#数据集处理" class="headerlink" title="数据集处理"></a>数据集处理</h2><p>测试AlexNet网络我们采用花数据集，数据集下载地址：<a target="_blank" rel="noopener" href="http://download.tensorflow.org/example_images/flower_photos.tgz">http://download.tensorflow.org/example_images/flower_photos.tgz</a></p>
<p>包含 5 种类型的花，每种类型有600~900张图像不等。</p>
<p>由于此数据集不像 CIFAR10 那样下载时就划分好了训练集和测试集，因此需要自己划分。</p>
<p>编写split_data.py分类脚本：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在，则先删除原文件夹在重新创建
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.1

    # 指向你解压后的flower_photos文件夹
    cwd = os.getcwd()
    data_root = os.path.join(cwd, "flower_data")
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>运行此脚本即可将数据分为训练集和验证集两部分</p>
<h2 id="train-1"><a href="#train-1" class="headerlink" title="train"></a>train</h2><pre class="line-numbers language-python" data-language="python"><code class="language-python">data_transforms = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),#随机裁剪 再缩放成224*224
                                 transforms.RandomHorizontalFlip(),#随机翻转，默认概率0.5
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]),
    "val": transforms.Compose([transforms.Resize((224,224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using {} device".format(device))


root_path = os.path.abspath(os.path.join(os.getcwd(),"../"))
image_path = os.path.join(root_path,"data","flower_data")
assert os.path.exists(image_path),"{} path does not exist".format(image_path)

#导入数据集
train_dataset = datasets.ImageFolder(root=os.path.join(image_path,"train"),transform=data_transforms["train"])
train_num = len(train_dataset)

val_dataset = datasets.ImageFolder(root=os.path.join(image_path,"val"),transform=data_transforms["val"])
val_num = len(val_dataset)

print("{} images for training , {} images for validate".format(train_num,val_num))

#{'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((value,key) for (key,value) in flower_list.items())

#写入json文件，方便后续根据索引找到标签
json_str = json.dumps(cla_dict,indent=4)
with open('class_indices.json','w') as file:
    file.write(json_str)

#训练批次大小
batch_size = 32

#加载数据集
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
val_loader = DataLoader(val_dataset,batch_size=batch_size,shuffle=False)

#加载模型
net = AlexNet(num_classes=5,init_weights=True)
net = net.to(device)

#定义loss function
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

#定义优化器，选用Adam
learning_rate = 0.001
optimizer = optim.Adam(net.parameters(),lr=learning_rate)

#训练次数
times = 10
save_path = "../model/AlexNet_Flower.pth"
best_acc = 0.0
train_steps = len(train_loader)

for epoch in range(times):
    #开始训练
    net.train() #模型进入训练模式，开启DropOut
    train_loss = 0
    train_bar = tqdm(train_loader,file=sys.stdout) #进度条展示
    for step,data in enumerate(train_bar):
        imgs,labels = data
        imgs = imgs.to(device)
        labels = labels.to(device)
        outputs = net(imgs)
        loss = loss_fn(outputs,labels)
        train_loss += loss.item()

        #更新参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_bar.desc = "train epoch[{}/{}]   loss:{:.3f}".format(epoch+1,times,loss)


    #开始验证
    net.eval()  #进入验证模式，关闭DropOut
    acc_rate = 0.0
    acc_num = 0
    with torch.no_grad():
        val_bar = tqdm(val_loader,file=sys.stdout)
        for data in val_bar:
            imgs, labels = data
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = net(imgs)
            acc = (outputs.argmax(1)==labels).sum()
            acc_num += acc

    acc_rate = acc_num/val_num
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
          (epoch + 1, train_loss / train_steps, acc_rate))

    if acc_rate &gt; best_acc:
        best_acc = acc_rate
        torch.save(net.state_dict(),save_path)


print("Finish Training,Best acc rate is {:.3f}".format(best_acc))<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="predict-1"><a href="#predict-1" class="headerlink" title="predict"></a>predict</h2><pre class="line-numbers language-python" data-language="python"><code class="language-python">device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("using {} device".format(device))


# Preprocessing must match training: resize to the 224x224 input AlexNet
# expects, convert to tensor, then normalize each RGB channel to [-1, 1].
data_transform = transforms.Compose([transforms.Resize((224, 224)),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

img_path = "../data/flower_data/dandelion.png"
assert os.path.exists(img_path), "file {} does not exist".format(img_path)

img = Image.open(img_path)

plt.imshow(img)

# [C, H, W] after the transform
img = data_transform(img)

# expand batch dimension -> [N, C, H, W]
img = torch.unsqueeze(img, dim=0)

# read the class-index -> class-name mapping saved during training
json_path = "./class_indices.json"
assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

with open(json_path, 'r') as f:
    class_indict = json.load(f)

model = AlexNet(num_classes=5).to(device)

# load model weights
weight_path = "../model/AlexNet_Flower.pth"
assert os.path.exists(weight_path), "file: '{}' does not exist.".format(weight_path)
model.load_state_dict(torch.load(weight_path))


# inference
model.eval()    # disables Dropout
with torch.no_grad():
    # move the result back to the CPU: .numpy() raises a TypeError on CUDA tensors
    output = torch.squeeze(model(img.to(device))).cpu()
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()


print_res = "class:{}  prob:{:.3}".format(class_indict[str(predict_cla)],
                                          predict[predict_cla].numpy())

plt.title(print_res)

for i in range(len(predict)):
    print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                              predict[i].numpy()))

plt.show()<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>预测结果</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019213240033.png" alt="image-20221019213240033"></p>
<h1 id="VGG"><a href="#VGG" class="headerlink" title="VGG"></a>VGG</h1><h2 id="VGG介绍"><a href="#VGG介绍" class="headerlink" title="VGG介绍"></a>VGG介绍</h2><p>VGG在2014年由牛津大学著名研究组VGG (Visual Geometry Group) 提出，斩获该年ImageNet竞赛中 Localization Task (定位任务) 第一名 和 Classification Task (分类任务) 第二名。</p>
<p>网络中的亮点：通过堆叠多个3x3的卷积核来替代大尺度卷积核（减少所需参数）</p>
<p>论文中提到，可以通过堆叠两个3x3的卷积核替代5x5的卷积核，堆叠三个3x3的卷积核替代7x7的卷积核。（即拥有相同的感受野）</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019205435083.png" alt="image-20221019205435083"></p>
<h2 id="CNN感受野"><a href="#CNN感受野" class="headerlink" title="CNN感受野"></a>CNN感受野</h2><p>在卷积神经网络中，决定某一层输出结果中一个元素所对应的输入层的区域大小，被称作感受野（receptive field）。</p>
<p>通俗的解释是，输出feature map上的一个单元 对应 输入层上的区域大小。</p>
<p>以下图为例，输出层 layer3 中一个单元 对应 输入层 layer2 上区域大小为2×2（池化操作），对应输入层 layer1 上大小为5×5</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019204645358.png" alt="image-20221019204645358"></p>
<p>感受野计算公式为$ F(i) = (F(i+1)-1) * Stride + Ksize $</p>
<p>论文中提到，可以通过堆叠两个3x3的卷积核替代5x5的卷积核，堆叠三个3x3的卷积核替代7x7的卷积核。接下来我们来证明（注：VGG网络中卷积的Stride默认为1）</p>
<script type="math/tex; mode=display">
Feature Map : F=1 \\
Conv3*3(3):F=(1-1)*1+3=3\\
Conv3*3(2):F=(3-1)*1+3=5（5×5卷积核感受野）\\
Conv3*3(1):F=(5-1)*1+3=7（7×7卷积核感受野）</script><blockquote>
<p>堆叠3×3卷积核后训练参数是否真的减少了？</p>
</blockquote>
<p>CNN参数个数 = 卷积核尺寸×卷积核深度 × 卷积核组数 = 卷积核尺寸 × 输入特征矩阵深度 × 输出特征矩阵深度</p>
<ul>
<li>使用7×7卷积核所需参数个数：$7\times 7\times C\times C=49C^2$</li>
<li>堆叠三个3×3的卷积核所需参数个数：$3\times 3\times C\times C+3\times 3\times C\times C+3\times 3\times C\times C=27C^2$</li>
</ul>
<h2 id="VGG-16"><a href="#VGG-16" class="headerlink" title="VGG-16"></a>VGG-16</h2><p>VGG网络有多个版本，一般常用的是VGG-16模型，其网络结构如下如所示：</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221019205557224.png" alt="image-20221019205557224"></p>
<h2 id="model-2"><a href="#model-2" class="headerlink" title="model"></a>model</h2><p>跟AlexNet中网络模型的定义一样，VGG网络也是分为 卷积层提取特征 和 全连接层进行分类 这两个模块</p>
<p>不同的是，VGG网络有 VGG-13、VGG-16等多种网络结构，能不能将这几种结构统一成一个模型呢？以上图的A、B、D、E模型为例，其全连接层完全一样，卷积层只有卷积核个数稍有不同。</p>
<p>我们可以用字典来统一模型</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python"># official pretrain weights
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
}


class VGG(nn.Module):
    """Configurable VGG classification network.

    Args:
        features: nn.Sequential of conv/pool layers (built by ``make_features``).
        num_classes: number of output classes of the final linear layer.
        init_weights: if True, apply Xavier initialization to all conv/linear layers.
    """

    def __init__(self,features,num_classes=1000,init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        # the classifier expects a flattened 512*7*7 feature vector,
        # i.e. the feature-extractor output for a 224x224 input image
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run a batch of images through the network; returns raw logits."""
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        # Xavier-uniform weights and zero biases for every conv and linear layer
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)




# VGG configuration table: each integer is the output-channel count of a
# 3x3 conv layer, 'M' marks a 2x2 max-pool layer
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],											# model A
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],									# model B
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],					# model D
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 	# model E
}

# builds the convolutional feature-extraction stage
def make_features(cfg: list):
    """Assemble the VGG feature extractor described by *cfg*.

    Each integer entry appends a 3x3 convolution (stride 1, padding 1,
    so the spatial size is preserved) followed by an in-place ReLU; the
    sentinel 'M' appends a 2x2 max-pool that halves height and width.

    Args:
        cfg: mixed list of output-channel counts and 'M' markers.

    Returns:
        nn.Sequential containing the assembled layers in order.
    """
    modules = []
    prev_channels = 3  # the network input is an RGB image
    for entry in cfg:
        if entry == 'M':
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            modules.append(nn.Conv2d(prev_channels, entry, kernel_size=3, padding=1))
            modules.append(nn.ReLU(True))
            prev_channels = entry
    # the single star (*) unpacks the list into positional arguments
    return nn.Sequential(*modules)


def vgg(model_name="vgg16",**kwargs):# the double star (**) collects extra keyword arguments into a dict
    """Instantiate a VGG model by configuration name.

    Args:
        model_name: key into ``cfgs`` ('vgg11', 'vgg13', 'vgg16' or 'vgg19').
        **kwargs: forwarded to ``VGG.__init__`` (e.g. num_classes, init_weights).

    Returns:
        VGG: the assembled network.
    """
    assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
    cfg = cfgs[model_name]

    model = VGG(make_features(cfg),**kwargs)

    return model
<blockquote>
<p>这里遇到了python中一个奇怪的写法，参数前面加<code>*</code>或者<code>**</code>，可能大佬觉得很简单，我从Java转过来不太了解</p>
<p>查阅资料发现其实就是可变参数的写法，具体参考<a target="_blank" rel="noopener" href="https://www.runoob.com/w3cnote/python-one-and-two-star.html">Python 函数参数前面一个星号（*）和两个星号（**）的区别</a></p>
</blockquote>
<p>Train和Predict的实现其实和上面的AlexNet如出一辙，只是需要注意的是由于VGG网络模型较深，需要使用GPU进行训练（而且要显存大一点的GPU）。</p>
<h1 id="GoogLeNet"><a href="#GoogLeNet" class="headerlink" title="GoogLeNet"></a>GoogLeNet</h1><h2 id="GoogLeNet介绍"><a href="#GoogLeNet介绍" class="headerlink" title="GoogLeNet介绍"></a>GoogLeNet介绍</h2><p>GoogLeNet在2014年由Google团队提出（与VGG网络同年，GoogLeNet中的L大写是为了致敬LeNet），斩获当年ImageNet竞赛中Classification Task (分类任务) 第一名。</p>
<p>GoogLeNet创新点：</p>
<ul>
<li>引入了Inception结构（融合不同尺度的特征信息）</li>
<li>使用1x1的卷积核进行降维以及映射处理</li>
<li>添加两个辅助分类器帮助训练</li>
<li>丢弃全连接层，使用平均池化层（大大减少模型参数）</li>
</ul>
<h2 id="Inception结构"><a href="#Inception结构" class="headerlink" title="Inception结构"></a>Inception结构</h2><p>传统的CNN结构如AlexNet、VggNet都是串联的结构，即将一系列的卷积层和池化层进行串联得到的结构。</p>
<p>GoogLeNet 提出了一种并联结构，下图是论文中提出的inception原始结构，将特征矩阵同时输入到多个分支进行处理，并将输出的特征矩阵按深度进行拼接，得到最终输出。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200717121029783.png" alt="Inception原始" style="zoom:67%;"></p>
<p>在 inception 的基础上，还可以加上降维功能的结构，如下图所示，在原始 inception 结构的基础上，在分支2，3，4上加入了卷积核大小为1x1的卷积层，目的是为了降维（减小深度），减少模型训练参数，减少计算量。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221020141913041.png" alt="image-20221020141913041"></p>
<blockquote>
<p>注意：每个分支所得的特征矩阵高和宽必须相同，这样最后才能沿深度（channel）方向进行拼接</p>
</blockquote>
<h2 id="1×1卷积核的降维功能"><a href="#1×1卷积核的降维功能" class="headerlink" title="1×1卷积核的降维功能"></a>1×1卷积核的降维功能</h2><p>同样是对一个深度为512的特征矩阵使用64个大小为5x5的卷积核进行卷积，不使用1x1卷积核进行降维的 话一共需要819200个参数，如果使用1x1卷积核进行降维一共需要50688个参数，明显少了很多。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200717122403870.png" alt="在这里插入图片描述" style="zoom:80%;"></p>
<blockquote>
<p>CNN参数个数 = 卷积核尺寸×卷积核深度 × 卷积核组数 = 卷积核尺寸 × 输入特征矩阵深度 × 输出特征矩阵深度</p>
</blockquote>
<h2 id="辅助分类器（Auxiliary-Classifier）"><a href="#辅助分类器（Auxiliary-Classifier）" class="headerlink" title="辅助分类器（Auxiliary Classifier）"></a>辅助分类器（Auxiliary Classifier）</h2><p>AlexNet 和 VGG 都只有1个输出层，GoogLeNet 有3个输出层，其中的两个是辅助分类层。</p>
<p>在训练模型时，将两个辅助分类器的损失乘以权重（论文中是0.3）加到网络的整体损失上，再进行反向传播。</p>
<p>辅助分类器的两个分支有什么用呢？</p>
<ul>
<li>作用一：可以把他看做inception网络中的一个小细节，它确保了即便是隐藏单元和中间层也参与了特征计算，他们也能预测图片的类别，他在inception网络中起到一种调整的效果，并且能防止网络发生过拟合。</li>
<li>作用二：给定深度相对较大的网络，让梯度有效地反向传播通过所有层是一个问题。通过将辅助分类器添加到这些中间层，可以增强较低阶段分类器的判别力。在训练期间，它们的损失以折扣权重（辅助分类器损失的权重是0.3）加到网络的整个损失上。</li>
</ul>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221020142820557.png" alt="image-20221020142820557"></p>
<p>论文参数描述：</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221020143035737.png" alt="image-20221020143035737"></p>
<h2 id="model-3"><a href="#model-3" class="headerlink" title="model"></a>model</h2><p>GoogLeNet完整结构</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200717161450737.png" alt="20200717161450737"></p>
<p>下面是原论文中给出的网络参数列表</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221020143146700.png" alt="image-20221020143146700" style="zoom:80%;"></p>
<p>可以看出无论是参数还是层数都很多，我们分模块来编写</p>
<p>首先是最基础的卷积层，由于每次卷积后都需要一次ReLu激活，我们整合在一起</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">#基础卷积层（卷积+ReLU）
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>接下来编写Inception模块，所需要使用到参数有<code>#1x1</code>, <code>#3x3reduce</code>, <code>#3x3</code>, <code>#5x5reduce</code>, <code>#5x5</code>, <code>poolproj</code>，这6个参数，分别对应着所使用的卷积核个数。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221020143529394.png" alt="image-20221020143529394" style="zoom:80%;"></p>
<ul>
<li><code>#1x1</code>对应着分支1上1x1的卷积核个数</li>
<li><code>#3x3reduce</code>对应着分支2上1x1的卷积核个数</li>
<li><code>#3x3</code>对应着分支2上3x3的卷积核个数</li>
<li><code>#5x5reduce</code>对应着分支3上1x1的卷积核个数</li>
<li><code>#5x5</code>对应着分支3上5x5的卷积核个数</li>
<li><code>poolproj</code>对应着分支4上1x1的卷积核个数。</li>
</ul>
<pre class="line-numbers language-python" data-language="python"><code class="language-python"># Inception结构 每个分支所得的特征矩阵高和宽必须相同
class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # 保证输出大小等于输入大小
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # 在官方的实现中，其实是3x3的kernel并不是5x5，具体可以参考下面的issue
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # 保证输出大小等于输入大小
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1) # 按 channel 对四个分支拼接<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>接着是辅助分类器模块，我们对照论文给的参数来编写</p>
<blockquote>
<p>这里的training是一个布尔类型，当我们实例化一个model之后，可以通过model.train()和model.eval()来控制模型的状态，在训练模式下training=True，验证模式下training=False</p>
</blockquote>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">#辅助分类器
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>最后整合起来</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">import torch.nn as nn
import torch
import torch.nn.functional as F



class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) classification network.

    Stem convolutions are followed by stacked Inception blocks; a global
    average pool replaces the large fully-connected layers of earlier
    architectures.  When ``aux_logits`` is True and the model is in
    training mode, ``forward`` also returns the outputs of two auxiliary
    classifier heads attached after inception4a and inception4d.
    """

    # aux_logits=True builds the two auxiliary classifiers used during
    # training; aux_logits=False omits them for validation/inference
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogLeNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        # ceil_mode=True rounds the pooled output size up instead of down
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Inception(in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj)
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_logits:
            # aux1 sees the 512-channel output of inception4a,
            # aux2 sees the 528-channel output of inception4d
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        # global average pool to 1x1, so any input size >= stem minimum works
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Returns logits; in training mode with aux_logits also (x, aux2, aux1)."""
        # N x 3 x 224 x 224
        x = self.conv1(x)
        # N x 64 x 112 x 112
        x = self.maxpool1(x)
        # N x 64 x 56 x 56
        x = self.conv2(x)
        # N x 64 x 56 x 56
        x = self.conv3(x)
        # N x 192 x 56 x 56
        x = self.maxpool2(x)

        # N x 192 x 28 x 28
        x = self.inception3a(x)
        # N x 256 x 28 x 28
        x = self.inception3b(x)
        # N x 480 x 28 x 28
        x = self.maxpool3(x)
        # N x 480 x 14 x 14
        x = self.inception4a(x)
        # N x 512 x 14 x 14
        if self.training and self.aux_logits:    # auxiliary heads are skipped in eval mode
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        # N x 512 x 14 x 14
        x = self.inception4c(x)
        # N x 512 x 14 x 14
        x = self.inception4d(x)
        # N x 528 x 14 x 14
        if self.training and self.aux_logits:    # auxiliary heads are skipped in eval mode
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        # N x 832 x 14 x 14
        x = self.maxpool4(x)
        # N x 832 x 7 x 7
        x = self.inception5a(x)
        # N x 832 x 7 x 7
        x = self.inception5b(x)
        # N x 1024 x 7 x 7

        x = self.avgpool(x)
        # N x 1024 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 1024
        x = self.dropout(x)
        x = self.fc(x)
        # N x 1000 (num_classes)
        if self.training and self.aux_logits:   # auxiliary heads are skipped in eval mode
            return x, aux2, aux1
        return x

    def _initialize_weights(self):
        # Kaiming-normal init for convs, small-std normal for linear layers
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

# Inception block: all four branches must keep the same spatial size (H, W)
class Inception(nn.Module):
    """Parallel four-branch Inception block (GoogLeNet / Inception v1).

    The branches (1x1 conv; 1x1->3x3 convs; 1x1->5x5 convs; 3x3 max-pool
    -> 1x1 conv) preserve height and width, so their outputs can be
    concatenated along the channel dimension.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()

        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)   # padding keeps output size equal to input size
        )

        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            # the official torchvision implementation actually uses a 3x3 kernel here, not 5x5
            # Please see https://github.com/pytorch/vision/issues/906 for details.
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)   # padding keeps output size equal to input size
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1) # concatenate the four branches along the channel dimension


# auxiliary classifier head (attached after inception4a / inception4d)
class InceptionAux(nn.Module):
    """Small side classifier that adds extra gradient signal during training.

    Average-pool + 1x1 conv reduce the 14x14 feature map to [N, 128, 4, 4]
    before two fully-connected layers produce per-class logits.
    """

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output[batch, 128, 4, 4]

        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14
        x = self.averagePool(x)
        # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)
        # dropout is only active while self.training is True
        x = F.dropout(x, 0.5, training=self.training)
        # N x 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # N x 1024
        x = self.fc2(x)
        # N x num_classes
        return x


#基础卷积层（卷积+ReLU）
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></s
pan><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="train-2"><a href="#train-2" class="headerlink" title="train"></a>train</h2><p>训练部分跟AlexNet和VGG类似，有两点需要注意：</p>
<ol>
<li><p>实例化网络时的参数</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">net = GoogLeNet(num_classes=5, aux_logits=True, init_weights=True)<span aria-hidden="true" class="line-numbers-rows"><span></span></span></code></pre>
</li>
<li><p>GoogLeNet的网络输出 loss 有三个部分，分别是主干输出loss、两个辅助分类器输出loss（权重0.3）</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">logits, aux_logits2, aux_logits1 = net(images.to(device))
loss0 = loss_function(logits, labels.to(device))
loss1 = loss_function(aux_logits1, labels.to(device))
loss2 = loss_function(aux_logits2, labels.to(device))
loss = loss0 + loss1 * 0.3 + loss2 * 0.3<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span></span></code></pre>
</li>
</ol>
<h2 id="predict-2"><a href="#predict-2" class="headerlink" title="predict"></a>predict</h2><p>预测部分跟AlexNet和VGG类似，需要注意在实例化模型时不需要 辅助分类器</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python"># create model
model = GoogLeNet(num_classes=5, aux_logits=False)

# load model weights
model_weight_path = "./googleNet.pth"<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>但是在加载训练好的模型参数时，由于其中是包含有辅助分类器的，需要设置<code>strict=False</code></p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">missing_keys, unexpected_keys = model.load_state_dict(torch.load(model_weight_path), strict=False)<span aria-hidden="true" class="line-numbers-rows"><span></span></span></code></pre>
<h1 id="ResNet"><a href="#ResNet" class="headerlink" title="ResNet"></a>ResNet</h1><h2 id="ResNet介绍"><a href="#ResNet介绍" class="headerlink" title="ResNet介绍"></a>ResNet介绍</h2><p>原论文地址：<a target="_blank" rel="noopener" href="https://arxiv.org/abs/1512.03385">Deep Residual Learning for Image Recognition</a>（作者是CV大佬何凯明团队）</p>
<p>ResNet在2015年由微软实验室提出，斩获当年ImageNet竞赛中分类任务第一名，目标检测第一名。获得COCO数据集中目标检测第一名，图像分割第一名。</p>
<p>ResNet网络的亮点：</p>
<ul>
<li>提出 Residual 结构（残差结构），并搭建超深的网络结构（可突破1000层）</li>
<li>使用 Batch Normalization 加速训练（丢弃dropout）</li>
</ul>
<p>下图是ResNet34层模型的结构简图：</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200719204103322.png" alt=""></p>
<h2 id="传统CNN存在的问题"><a href="#传统CNN存在的问题" class="headerlink" title="传统CNN存在的问题"></a>传统CNN存在的问题</h2><p>在ResNet网络提出之前，传统的卷积神经网络都是通过将一系列卷积层与池化层进行堆叠得到的。</p>
<p>一般我们会觉得网络越深，特征信息越丰富，模型效果应该越好。但是实验证明，当网络堆叠到一定深度时，会出现两个问题：</p>
<ol>
<li><p>梯度消失或梯度爆炸</p>
<p>梯度消失：若每一层的误差梯度小于1，反向传播时，网络越深，梯度越趋近于0<br>梯度爆炸：反之，若每一层的误差梯度大于1，反向传播时，网路越深，梯度越来越大</p>
</li>
<li><p>退化问题(degradation problem)：在解决了梯度消失、爆炸问题后，仍然存在深层网络的效果可能比浅层网络差的现象</p>
</li>
</ol>
<p>总结就是，当网络堆叠到一定深度时，反而会出现深层网络比浅层网络效果差的情况，如下图</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200719205325378.png" alt="退化问题" style="zoom:80%;"></p>
<blockquote>
<p>对于梯度消失或梯度爆炸问题，ResNet论文提出通过数据的预处理以及在网络中使用 BN（Batch Normalization）层来解决。</p>
<p>对于退化问题，ResNet论文提出了 residual结构（残差结构）来减轻退化问题，下图是使用residual结构的卷积网络，可以看到随着网络的不断加深，效果并没有变差，而是变的更好了。（虚线是train error，实线是test error）</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200805095831298.png" alt="" style="zoom:80%;"></p>
</blockquote>
<h2 id="Batch-Normalization"><a href="#Batch-Normalization" class="headerlink" title="Batch Normalization"></a>Batch Normalization</h2><p>我们在图像预处理过程中通常会对图像进行标准化处理，这样能够加速网络的收敛，如下图所示，对于Conv1来说输入的就是满足某一分布的特征矩阵，但对于Conv2而言输入的feature map就不一定满足某一分布规律了（注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律，理论上是指整个训练样本集所对应feature map的数据要满足分布规律）。而我们Batch Normalization的目的就是使我们的feature map满足均值为0，方差为1的分布规律。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200221211618570.png" alt="img" style="zoom: 50%;"></p>
<p>具体原理和实现参照这篇博文：<a target="_blank" rel="noopener" href="https://blog.csdn.net/qq_37541097/article/details/104434557">https://blog.csdn.net/qq_37541097/article/details/104434557</a></p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221021192200974.png" alt="image-20221021192200974" style="zoom: 67%;"></p>
<p>使用BN时需要注意的问题：</p>
<ol>
<li>训练时要将training参数设置为True，在验证时将training参数设置为False。在pytorch中可通过创建模型的model.train()和model.eval()方法控制。</li>
<li>batch size尽可能设置大点，设置小后表现可能很糟糕，设置的越大求的均值和方差越接近整个训练集的均值和方差。</li>
<li>建议将bn层放在卷积层（Conv）和激活层（例如Relu）之间，且卷积层不要使用偏置bias，因为没有用。</li>
</ol>
<h2 id="Residual"><a href="#Residual" class="headerlink" title="Residual"></a>Residual</h2><p>为了解决深层网络中的退化问题，可以人为地让神经网络某些层跳过下一层神经元的连接，隔层相连，弱化每层之间的强联系。这种神经网络被称为 残差网络 (ResNets)。</p>
<p>残差网络由许多隔层相连的神经元子模块组成，我们称之为 残差块 Residual block。</p>
<p><img src="C:\Users\Administrator\Desktop\学习\博客图片\image-20221021192516837.png" alt="image-20221021192516837"></p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221021193124160.png" alt="image-20221021193124160"></p>
<p>我们可以根据Residual残差结构写出$X_L$与上面任意一层$X_l$之间的关系</p>
<blockquote>
<p>对于resnet残差连接可以用“传话筒”游戏来通俗理解：类似于《王牌》中的传话筒，腾哥在看到了“狗中赤兔”这个词后，形象地演给花花看，花花又演给晓彤看，最后晓彤演给玲姐看，结果玲姐看完一脸懵～。可以看出，“狗中赤兔”在传递过程中信息是不断减少的，腾哥获得了最多的信息，而玲姐获得的最少，这就类似于浅层网络获得的信息多，而深层少，最终深层网络无法理解传来的信息，也就是玲姐猜不出来题。（这一现象称之为“梯度消失”，就是指信息一层层不断减少直至消失）那怎么办呢？为了解决这个问题，腾哥就跳过花花晓彤，单独给玲姐演了一遍，结果玲姐顿悟—“狗中赤兔”！这相当于浅层网络绕开中间网络，把信息直接传给了深层网络，深层网络秒懂。残差连接就是将信息直接传给深层网络，避免了浅层网络对信息的削减。（还有一种“梯度爆炸”现象，是指每一层网络传递的信息越来越多，导致深层网络直接“死机”了）</p>
</blockquote>
<h2 id="ResNet中的残差结构"><a href="#ResNet中的残差结构" class="headerlink" title="ResNet中的残差结构"></a>ResNet中的残差结构</h2><p>实际应用中，残差结构的 short cut 不一定是隔一层连接，也可以中间隔多层，ResNet所提出的残差网络中就是隔多层。</p>
<p>跟VggNet类似，ResNet也有多个不同层的版本，而残差结构也有两种对应浅层和深层网络：</p>
<div class="table-container">
<table>
<thead>
<tr>
<th></th>
<th>ResNet</th>
<th>残差结构</th>
</tr>
</thead>
<tbody>
<tr>
<td>浅层网络</td>
<td>ResNet18/34</td>
<td>BasicBlock</td>
</tr>
<tr>
<td>深层网络</td>
<td>ResNet50/101/152</td>
<td>Bottleneck</td>
</tr>
</tbody>
</table>
</div>
<p>下图中左侧残差结构称为 <strong>BasicBlock</strong>，右侧残差结构称为 <strong>Bottleneck</strong></p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200805101305631.png" alt=""></p>
<p>对于深层的 Bottleneck，1×1的卷积核起到降维和升维（特征矩阵深度）的作用，同时可以大大减少网络参数。</p>
<blockquote>
<p>可以计算一下，假设两个残差结构的输入特征和输出特征矩阵的深度都是256维，如下图：（注意左侧结构的改动）</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200805110510978.png" alt=""></p>
<p>那么两个残差结构所需的参数为：</p>
<p>左侧：$3\times 3\times 256\times 256+3\times 3\times 256\times 256=1179648$<br>右侧：$1\times 1\times 256\times 64+3\times 3\times 64\times 64+1\times 1\times 64\times 256=69632$</p>
<p>明显搭建深层网络时，使用右侧的残差结构更合适。</p>
</blockquote>
<h2 id="Short-cut的维度处理"><a href="#Short-cut的维度处理" class="headerlink" title="Short cut的维度处理"></a>Short cut的维度处理</h2><p>我们来看ResNet18层网络结构图，可以发现有些残差块的 short cut 是实线的，而有些则是虚线的。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200805143600781.png" style="zoom:67%;"></p>
<p>这些虚线的 short cut 上通过1×1的卷积核进行了维度处理（特征矩阵在长宽方向降采样，深度方向调整成下一层残差结构所需要的channel）。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200805145611987.png" alt=""></p>
<p>原文的表注中已说明，conv3_x, conv4_x, conv5_x所对应的一系列残差结构的第一层残差结构都是虚线残差结构。因为这一系列残差结构的第一层都有调整输入特征矩阵shape的使命（将特征矩阵的高和宽缩减为原来的一半，将深度channel调整成下一层残差结构所需要的channel）</p>
<p>需要注意的是，对于ResNet50/101/152，其实conv2_x所对应的一系列残差结构的第一层也是虚线残差结构，因为它需要调整输入特征矩阵的channel。根据表格可知通过3x3的max pool之后输出的特征矩阵shape应该是[56, 56, 64]，但conv2_x所对应的一系列残差结构中的实线残差结构它们期望的输入特征矩阵shape是[56, 56, 256]（因为这样才能保证输入输出特征矩阵shape相同，才能将捷径分支的输出与主分支的输出进行相加）。所以第一层残差结构需要将shape从[56, 56, 64] —&gt; [56, 56, 256]。注意，这里只调整channel维度，高和宽不变（而conv3_x, conv4_x, conv5_x所对应的一系列残差结构的第一层虚线残差结构不仅要调整channel还要将高和宽缩减为原来的一半）。</p>
<ul>
<li><p>ResNet 18/34：</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200805112731312.png" style="zoom: 80%;"></p>
</li>
<li><p>ResNet 50/101/152：</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200805112656263.png" style="zoom:80%;"></p>
</li>
</ul>
<h2 id="model-4"><a href="#model-4" class="headerlink" title="model"></a>model</h2><p>定义ResNet18/34的残差结构，即BasicBlock</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python"># ResNet18/34的残差结构，用的是2个3x3的卷积
class BasicBlock(nn.Module):
    expansion = 1  # 残差结构中，主分支的卷积核个数是否发生变化，不变则为1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):  # downsample对应虚线残差结构
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel) #Batch Normalization
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:  # 虚线残差结构，需要下采样
            identity = self.downsample(x)  # 捷径分支 short cut

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>定义ResNet50/101/152的残差结构，即Bottleneck</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python"># ResNet50/101/152的残差结构，用的是1x1+3x3+1x1的卷积
class Bottleneck(nn.Module):
    expansion = 4  # 残差结构中第三层卷积核个数是第一/二层卷积核个数的4倍

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)  # 捷径分支 short cut

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>定义ResNet网络结构</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">class ResNet(nn.Module):
    # block = BasicBlock or Bottleneck
    # block_num为残差结构中conv2_x~conv5_x中残差块个数，是一个列表
    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])             # conv2_x
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)  # conv3_x
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)  # conv4_x
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)  # conv5_x
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    # channel为残差结构中第一层卷积核个数
    def _make_layer(self, block, channel, block_num, stride=1):
        downsample = None

        # ResNet50/101/152的残差结构，block.expansion=4
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>定义resnet18/34/50/101/152</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="train-3"><a href="#train-3" class="headerlink" title="train"></a>train</h2><p>由于ResNet网络较深，直接训练的话会非常耗时，因此用迁移学习的方法导入预训练好的模型参数，以下是ResNet官方预训练好的模型下载地址：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>这里数据的预处理我们需要和官方的预处理一样，保证后面的迁移学习正确率：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
    "val": transforms.Compose([transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>载入模型时需要用到迁移学习：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">net = resnet34()
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
# for param in net.parameters():
#     param.requires_grad = False

# change fc layer structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)
net.to(device)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="ResNext"><a href="#ResNext" class="headerlink" title="ResNext"></a>ResNext</h2><p>ResNext其实就是在ResNet的基础上吸收了Inception结构，并引入了一个新的概念——分组Group——也称为cardinality</p>
<p>如下图所示，上面是我们传统的k*k的n维卷积核，下面是我们把input分为g组，卷积核也分为g组，所需要的参数，可以看出，分组分的越多，所需要的参数就越少。</p>
<p><img src="C:\Users\Administrator\Desktop\学习\博客图片\image-20221021211941330.png" alt="image-20221021211941330" style="zoom: 67%;"></p>
<p>和之前介绍的精心设计的Inception模块不一样，ResNext的Inception每个结构使用相同的拓扑结构。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221021212525367.png" alt="image-20221021212525367"></p>
<p>——————————————————————————————————————————————</p>
<blockquote>
<p>ResNext没太搞懂，以后再完善</p>
</blockquote>
<p>——————————————————————————————————————————————</p>
<h1 id="迁移学习"><a href="#迁移学习" class="headerlink" title="迁移学习"></a>迁移学习</h1><p>迁移学习是一个比较大的领域，我们这里说的迁移学习是指神经网络训练中使用到的迁移学习。</p>
<p>在迁移学习中，我们希望利用源任务（Source Task）学到的知识帮助学习目标任务 (Target Task)。例如，一个训练好的图像分类网络能够被用于另一个图像相关的任务。再比如，一个网络在仿真环境学习的知识可以被迁移到真实环境的网络。迁移学习一个典型的例子就是载入训练好VGG网络，这个大规模分类网络能将图像分到1000个类别，然后把这个网络用于另一个任务，如医学图像分类。</p>
<p>为什么可以这么做呢？如下图所示，神经网络逐层提取图像的深层信息，这样，预训练网络就相当于一个特征提取器。</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/20200815095455312.png" alt=""></p>
<p>使用迁移学习的优势：</p>
<ul>
<li>能够快速的训练出一个理想的结果</li>
<li>当数据集较小时也能训练出理想的效果</li>
</ul>
<blockquote>
<p>注意：使用别人预训练模型参数时，要注意别人的预处理方式。</p>
</blockquote>
<p>常见的迁移学习方式：</p>
<ol>
<li>载入权重后训练所有参数</li>
<li>载入权重后只训练最后几层参数</li>
<li>载入权重后在原网络基础上再添加一层全连接层，仅训练最后一个全连接层</li>
</ol>
<h1 id="结语"><a href="#结语" class="headerlink" title="结语"></a>结语</h1><p>RNG干碎T1吧🤔</p>
<iframe frameborder="no" border="0" marginwidth="0" marginheight="0" width="298" height="52" title="网易云音乐播放器" src="https://music.163.com/outchain/player?type=2&amp;id=1846565957&amp;auto=0&amp;height=32"></iframe>

                
            </div>
            <hr/>

            

    <div class="reprint" id="reprint-statement">
        
            <div class="reprint__author">
                <span class="reprint-meta" style="font-weight: bold;">
                    <i class="fas fa-user">
                        文章作者:
                    </i>
                </span>
                <span class="reprint-info">
                    <a href="/about" rel="external nofollow noreferrer">J Sir</a>
                </span>
            </div>
            <div class="reprint__type">
                <span class="reprint-meta" style="font-weight: bold;">
                    <i class="fas fa-link">
                        文章链接:
                    </i>
                </span>
                <span class="reprint-info">
                    <a href="https://jy741.gitee.io/2022/10/19/shen-du-xue-xi-tu-xiang-fen-lei-ru-men-parti/">https://jy741.gitee.io/2022/10/19/shen-du-xue-xi-tu-xiang-fen-lei-ru-men-parti/</a>
                </span>
            </div>
            <div class="reprint__notice">
                <span class="reprint-meta" style="font-weight: bold;">
                    <i class="fas fa-copyright">
                        版权声明:
                    </i>
                </span>
                <span class="reprint-info">
                    本博客所有文章除特別声明外，均采用
                    <a href="https://creativecommons.org/licenses/by/4.0/deed.zh" rel="external nofollow noreferrer" target="_blank">CC BY 4.0</a>
                    许可协议。转载请注明来源
                    <a href="/about" target="_blank">J Sir</a>
                    !
                </span>
            </div>
        
    </div>

    <script async defer>
      document.addEventListener("copy", function (e) {
        let toastHTML = '<span>复制成功，请遵循本文的转载规则</span><button class="btn-flat toast-action" onclick="navToReprintStatement()" style="font-size: smaller">查看</a>';
        M.toast({html: toastHTML})
      });

      function navToReprintStatement() {
        $("html, body").animate({scrollTop: $("#reprint-statement").offset().top - 80}, 800);
      }
    </script>



            <div class="tag_share" style="display: block;">
                <div class="post-meta__tag-list" style="display: inline-block;">
                    
                        <div class="article-tag">
                            
                                <a href="/tags/python/">
                                    <span class="chip bg-color">python</span>
                                </a>
                            
                                <a href="/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/">
                                    <span class="chip bg-color">深度学习</span>
                                </a>
                            
                        </div>
                    
                </div>
                <div class="post_share" style="zoom: 80%; width: fit-content; display: inline-block; float: right; margin: -0.15rem 0;">
                    <link rel="stylesheet" type="text/css" href="/libs/share/css/share.min.css">
<div id="article-share">

    
    <div class="social-share" data-sites="twitter,facebook,google,qq,qzone,wechat,weibo,douban,linkedin" data-wechat-qrcode-helper="<p>微信扫一扫即可分享！</p>"></div>
    <script src="/libs/share/js/social-share.min.js"></script>
    

    

</div>

                </div>
            </div>
            
                <style>
    #reward {
        margin: 40px 0;
        text-align: center;
    }

    #reward .reward-link {
        font-size: 1.4rem;
        line-height: 38px;
    }

    #reward .btn-floating:hover {
        box-shadow: 0 6px 12px rgba(0, 0, 0, 0.2), 0 5px 15px rgba(0, 0, 0, 0.2);
    }

    #rewardModal {
        width: 320px;
        height: 350px;
    }

    #rewardModal .reward-title {
        margin: 15px auto;
        padding-bottom: 5px;
    }

    #rewardModal .modal-content {
        padding: 10px;
    }

    #rewardModal .close {
        position: absolute;
        right: 15px;
        top: 15px;
        color: rgba(0, 0, 0, 0.5);
        font-size: 1.3rem;
        line-height: 20px;
        cursor: pointer;
    }

    #rewardModal .close:hover {
        color: #ef5350;
        transform: scale(1.3);
        -moz-transform:scale(1.3);
        -webkit-transform:scale(1.3);
        -o-transform:scale(1.3);
    }

    #rewardModal .reward-tabs {
        margin: 0 auto;
        width: 210px;
    }

    .reward-tabs .tabs {
        height: 38px;
        margin: 10px auto;
        padding-left: 0;
    }

    .reward-content ul {
        padding-left: 0 !important;
    }

    .reward-tabs .tabs .tab {
        height: 38px;
        line-height: 38px;
    }

    .reward-tabs .tab a {
        color: #fff;
        background-color: #ccc;
    }

    .reward-tabs .tab a:hover {
        background-color: #ccc;
        color: #fff;
    }

    .reward-tabs .wechat-tab .active {
        color: #fff !important;
        background-color: #22AB38 !important;
    }

    .reward-tabs .alipay-tab .active {
        color: #fff !important;
        background-color: #019FE8 !important;
    }

    .reward-tabs .reward-img {
        width: 210px;
        height: 210px;
    }
</style>

<div id="reward">
    <a href="#rewardModal" class="reward-link modal-trigger btn-floating btn-medium waves-effect waves-light red">赏</a>

    <!-- Modal Structure -->
    <div id="rewardModal" class="modal">
        <div class="modal-content">
            <a class="close modal-close"><i class="fas fa-times"></i></a>
            <h4 class="reward-title">你的赏识是我前进的动力</h4>
            <div class="reward-content">
                <div class="reward-tabs">
                    <ul class="tabs row">
                        <li class="tab col s6 alipay-tab waves-effect waves-light"><a href="#alipay">支付宝</a></li>
                        <li class="tab col s6 wechat-tab waves-effect waves-light"><a href="#wechat">微 信</a></li>
                    </ul>
                    <div id="alipay">
                        <img src="/medias/reward/alipay.jpg" class="reward-img" alt="支付宝打赏二维码">
                    </div>
                    <div id="wechat">
                        <img src="/medias/reward/wechat.png" class="reward-img" alt="微信打赏二维码">
                    </div>
                </div>
            </div>
        </div>
    </div>
</div>

<script>
    $(function () {
        $('.tabs').tabs();
    });
</script>

            
        </div>
    </div>

    

    

    

    

    

    

    

    

<article id="prenext-posts" class="prev-next articles">
    <div class="row article-row">
        
        <div class="article col s12 m6" data-aos="fade-up">
            <div class="article-badge left-badge text-color">
                <i class="fas fa-chevron-left"></i>&nbsp;上一篇</div>
            <div class="card">
                <a href="/2022/10/28/shen-du-xue-xi-tu-xiang-fen-lei-ru-men-partii/">
                    <div class="card-image">
                        
                        
                        <img src="/medias/featureimages/10.jpg" class="responsive-img" alt="深度学习——图像分类入门PartⅡ">
                        
                        <span class="card-title">深度学习——图像分类入门PartⅡ</span>
                    </div>
                </a>
                <div class="card-content article-content">
                    <div class="summary block-with-text">
                        
                            EverydayOneCat“什么时候?” “猫!”

                        
                    </div>
                    <div class="publish-info">
                        <span class="publish-date">
                            <i class="far fa-clock fa-fw icon-date"></i>2022-10-28
                        </span>
                        <span class="publish-author">
                            
                            <i class="fas fa-user fa-fw"></i>
                            J Sir
                            
                        </span>
                    </div>
                </div>
                
                <div class="card-action article-tags">
                    
                    <a href="/tags/python/">
                        <span class="chip bg-color">python</span>
                    </a>
                    
                    <a href="/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/">
                        <span class="chip bg-color">深度学习</span>
                    </a>
                    
                </div>
                
            </div>
        </div>
        
        
        <div class="article col s12 m6" data-aos="fade-up">
            <div class="article-badge right-badge text-color">
                下一篇&nbsp;<i class="fas fa-chevron-right"></i>
            </div>
            <div class="card">
                <a href="/2022/10/15/pytorch-ru-men/">
                    <div class="card-image">
                        
                        
                        <img src="/medias/featureimages/6.jpg" class="responsive-img" alt="Pytorch入门">
                        
                        <span class="card-title">Pytorch入门</span>
                    </div>
                </a>
                <div class="card-content article-content">
                    <div class="summary block-with-text">
                        
                            本文基于B站up我是土堆所做的视频教程为基础创作的学习笔记，如有侵权请联系作者。
                        
                    </div>
                    <div class="publish-info">
                            <span class="publish-date">
                                <i class="far fa-clock fa-fw icon-date"></i>2022-10-15
                            </span>
                        <span class="publish-author">
                            
                            <i class="fas fa-user fa-fw"></i>
                            J Sir
                            
                        </span>
                    </div>
                </div>
                
                <div class="card-action article-tags">
                    
                    <a href="/tags/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/">
                        <span class="chip bg-color">机器学习</span>
                    </a>
                    
                    <a href="/tags/python/">
                        <span class="chip bg-color">python</span>
                    </a>
                    
                </div>
                
            </div>
        </div>
        
    </div>
</article>

</div>



<!-- 代码块功能依赖 -->
<script type="text/javascript" src="/libs/codeBlock/codeBlockFuction.js"></script>

<!-- 代码语言 -->

<script type="text/javascript" src="/libs/codeBlock/codeLang.js"></script>


<!-- 代码块复制 -->

<script type="text/javascript" src="/libs/codeBlock/codeCopy.js"></script>


<!-- 代码块收缩 -->

<script type="text/javascript" src="/libs/codeBlock/codeShrink.js"></script>


    </div>
    <div id="toc-aside" class="expanded col l3 hide-on-med-and-down">
        <div class="toc-widget card" style="background-color: white;">
            <div class="toc-title"><i class="far fa-list-alt"></i>&nbsp;&nbsp;目录</div>
            <div id="toc-content"></div>
        </div>
    </div>
</div>

<!-- TOC 悬浮按钮. -->

<div id="floating-toc-btn" class="hide-on-med-and-down">
    <a class="btn-floating btn-large bg-color">
        <i class="fas fa-list-ul"></i>
    </a>
</div>


<script src="/libs/tocbot/tocbot.min.js"></script>
<script>
    // Table-of-contents behavior for the article sidebar:
    // builds the TOC with tocbot, rewrites anchors/ids so non-ASCII (Chinese)
    // headings link correctly, pins the TOC on scroll, and wires the floating
    // toggle button that shows/hides the TOC column.
    $(function () {
        tocbot.init({
            tocSelector: '#toc-content',
            contentSelector: '#articleContent',
            headingsOffset: -($(window).height() * 0.4 - 45),
            collapseDepth: Number('0'),
            headingSelector: 'h1, h2, h3, h4, h5'
        });

        // modify the toc link href to support Chinese.
        let i = 0;
        let tocHeading = 'toc-heading-';
        $('#toc-content a').each(function () {
            $(this).attr('href', '#' + tocHeading + (++i));
        });

        // modify the heading title id to support Chinese.
        // NOTE(review): relies on tocbot and the article emitting headings in
        // the same document order so the sequential counters line up.
        i = 0;
        $('#articleContent').children('h1, h2, h3, h4, h5').each(function () {
            $(this).attr('id', tocHeading + (++i));
        });

        // Set scroll toc fixed.
        let tocHeight = parseInt($(window).height() * 0.4 - 64);
        let $tocWidget = $('.toc-widget');
        $(window).scroll(function () {
            let scroll = $(window).scrollTop();
            /* add post toc fixed. */
            if (scroll > tocHeight) {
                $tocWidget.addClass('toc-fixed');
            } else {
                $tocWidget.removeClass('toc-fixed');
            }
        });

        
        /* Fix the width of the article card div: copy the detail card's width
           (plus an empirically tuned per-breakpoint offset) onto the target. */
        let fixPostCardWidth = function (srcId, targetId) {
            let srcDiv = $('#' + srcId);
            if (srcDiv.length === 0) {
                return;
            }

            let w = srcDiv.width();
            if (w >= 450) {
                w = w + 21;
            } else if (w >= 350 && w < 450) {
                w = w + 18;
            } else if (w >= 300 && w < 350) {
                w = w + 16;
            } else {
                w = w + 14;
            }
            $('#' + targetId).width(w);
        };

        // Toggle the TOC column open/closed from the floating button.
        const expandedClass = 'expanded';
        let $tocAside = $('#toc-aside');
        let $mainContent = $('#main-content');
        $('#floating-toc-btn .btn-floating').click(function () {
            if ($tocAside.hasClass(expandedClass)) {
                $tocAside.removeClass(expandedClass).hide();
                $mainContent.removeClass('l9');
            } else {
                $tocAside.addClass(expandedClass).show();
                $mainContent.addClass('l9');
            }
            // Re-sync the prev/next card width after the layout change.
            fixPostCardWidth('artDetail', 'prenext-posts');
        });
        
    });
</script>

    

</main>


<script src="https://cdn.bootcss.com/mathjax/2.7.5/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script>
    MathJax.Hub.Config({
        tex2jax: {inlineMath: [['$', '$'], ['\(', '\)']]}
    });
</script>



    <footer class="page-footer bg-color">
    
        <link rel="stylesheet" href="/libs/aplayer/APlayer.min.css">
<style>
    .aplayer .aplayer-lrc p {
        
        display: none;
        
        font-size: 12px;
        font-weight: 700;
        line-height: 16px !important;
    }

    .aplayer .aplayer-lrc p.aplayer-lrc-current {
        
        display: none;
        
        font-size: 15px;
        color: #42b983;
    }

    
    .aplayer.aplayer-fixed.aplayer-narrow .aplayer-body {
        left: -66px !important;
    }

    .aplayer.aplayer-fixed.aplayer-narrow .aplayer-body:hover {
        left: 0px !important;
    }

    
</style>
<div class="">
    
    <div class="row">
        <meting-js class="col l8 offset-l2 m10 offset-m1 s12"
                   server="netease"
                   type="playlist"
                   id="503838841"
                   fixed='true'
                   autoplay='false'
                   theme='#42b983'
                   loop='all'
                   order='random'
                   preload='auto'
                   volume='0.7'
                   list-folded='true'
        >
        </meting-js>
    </div>
</div>

<script src="/libs/aplayer/APlayer.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/meting@2/dist/Meting.min.js"></script>

    
    <div class="container row center-align" style="margin-bottom: 0px !important;">
        <div class="col s12 m8 l8 copy-right">
            Copyright&nbsp;&copy;
            
                <span id="year">2020-2023</span>
            
            <span>2020</span>
            <a href="/about" target="_blank">J Sir</a>
            |&nbsp;Powered by&nbsp;<a href="https://hexo.io/" target="_blank">Hexo</a>
            |&nbsp;Theme&nbsp;<a href="https://github.com/blinkfox/hexo-theme-matery" target="_blank">Matery</a>
            <br>
            
            &nbsp;<i class="fas fa-chart-area"></i>&nbsp;站点总字数:&nbsp;<span
                class="white-color">300.9k</span>&nbsp;字
            
            
            
            
            
            
            <span id="busuanzi_container_site_pv">
                |&nbsp;<i class="far fa-eye"></i>&nbsp;总访问量:&nbsp;<span id="busuanzi_value_site_pv"
                    class="white-color"></span>&nbsp;次
            </span>
            
            
            <span id="busuanzi_container_site_uv">
                |&nbsp;<i class="fas fa-users"></i>&nbsp;总访问人数:&nbsp;<span id="busuanzi_value_site_uv"
                    class="white-color"></span>&nbsp;人
            </span>
            
            <br>
            
            <br>
            
        </div>
        <div class="col s12 m4 l4 social-link social-statis">


    <a href="mailto:2065373132@qq.com" class="tooltipped" target="_blank" data-tooltip="邮件联系我" data-position="top" data-delay="50">
        <i class="fas fa-envelope-open"></i>
    </a>







    <a href="tencent://AddContact/?fromId=50&fromSubId=1&subcmd=all&uin=2065373132" class="tooltipped" target="_blank" data-tooltip="QQ联系我: 2065373132" data-position="top" data-delay="50">
        <i class="fab fa-qq"></i>
    </a>







    <a href="/atom.xml" class="tooltipped" target="_blank" data-tooltip="RSS 订阅" data-position="top" data-delay="50">
        <i class="fas fa-rss"></i>
    </a>

</div>
    </div>
</footer>

<div class="progress-bar"></div>


    <!-- 搜索遮罩框 -->
<div id="searchModal" class="modal">
    <div class="modal-content">
        <div class="search-header">
            <span class="title"><i class="fas fa-search"></i>&nbsp;&nbsp;搜索</span>
            <input type="search" id="searchInput" name="s" placeholder="请输入搜索的关键字"
                   class="search-input">
        </div>
        <div id="searchResult"></div>
    </div>
</div>

<script type="text/javascript">
$(function () {
    var searchFunc = function (path, search_id, content_id) {
        'use strict';
        $.ajax({
            url: path,
            dataType: "xml",
            success: function (xmlResponse) {
                // get the contents from search data
                var datas = $("entry", xmlResponse).map(function () {
                    return {
                        title: $("title", this).text(),
                        content: $("content", this).text(),
                        url: $("url", this).text()
                    };
                }).get();
                var $input = document.getElementById(search_id);
                var $resultContent = document.getElementById(content_id);
                $input.addEventListener('input', function () {
                    var str = '<ul class=\"search-result-list\">';
                    var keywords = this.value.trim().toLowerCase().split(/[\s\-]+/);
                    $resultContent.innerHTML = "";
                    if (this.value.trim().length <= 0) {
                        return;
                    }
                    // perform local searching
                    datas.forEach(function (data) {
                        var isMatch = true;
                        var data_title = data.title.trim().toLowerCase();
                        var data_content = data.content.trim().replace(/<[^>]+>/g, "").toLowerCase();
                        var data_url = data.url;
                        data_url = data_url.indexOf('/') === 0 ? data.url : '/' + data_url;
                        var index_title = -1;
                        var index_content = -1;
                        var first_occur = -1;
                        // only match artiles with not empty titles and contents
                        if (data_title !== '' && data_content !== '') {
                            keywords.forEach(function (keyword, i) {
                                index_title = data_title.indexOf(keyword);
                                index_content = data_content.indexOf(keyword);
                                if (index_title < 0 && index_content < 0) {
                                    isMatch = false;
                                } else {
                                    if (index_content < 0) {
                                        index_content = 0;
                                    }
                                    if (i === 0) {
                                        first_occur = index_content;
                                    }
                                }
                            });
                        }
                        // show search results
                        if (isMatch) {
                            str += "<li><a href='" + data_url + "' class='search-result-title'>" + data_title + "</a>";
                            var content = data.content.trim().replace(/<[^>]+>/g, "");
                            if (first_occur >= 0) {
                                // cut out 100 characters
                                var start = first_occur - 20;
                                var end = first_occur + 80;
                                if (start < 0) {
                                    start = 0;
                                }
                                if (start === 0) {
                                    end = 100;
                                }
                                if (end > content.length) {
                                    end = content.length;
                                }
                                var match_content = content.substr(start, end);
                                // highlight all keywords
                                keywords.forEach(function (keyword) {
                                    var regS = new RegExp(keyword, "gi");
                                    match_content = match_content.replace(regS, "<em class=\"search-keyword\">" + keyword + "</em>");
                                });

                                str += "<p class=\"search-result\">" + match_content + "...</p>"
                            }
                            str += "</li>";
                        }
                    });
                    str += "</ul>";
                    $resultContent.innerHTML = str;
                });
            }
        });
    };

    searchFunc('/search.xml', 'searchInput', 'searchResult');
});
</script>

    <!-- 回到顶部按钮 -->
<div id="backTop" class="top-scroll">
    <a class="btn-floating btn-large waves-effect waves-light" href="#!">
        <i class="fas fa-arrow-up"></i>
    </a>
</div>


    <script src="/libs/materialize/materialize.min.js"></script>
    <script src="/libs/masonry/masonry.pkgd.min.js"></script>
    <script src="/libs/aos/aos.js"></script>
    <script src="/libs/scrollprogress/scrollProgress.min.js"></script>
    <script src="/libs/lightGallery/js/lightgallery-all.min.js"></script>
    <script src="/js/matery.js"></script>

    <!-- Baidu Analytics -->

    <!-- Baidu Push -->

<script>
    (function () {
        var bp = document.createElement('script');
        var curProtocol = window.location.protocol.split(':')[0];
        if (curProtocol === 'https') {
            bp.src = 'https://zz.bdstatic.com/linksubmit/push.js';
        } else {
            bp.src = 'http://push.zhanzhang.baidu.com/push.js';
        }
        var s = document.getElementsByTagName("script")[0];
        s.parentNode.insertBefore(bp, s);
    })();
</script>

    
    <script src="/libs/others/clicklove.js" async="async"></script>
    
    
    <script async src="/libs/others/busuanzi.pure.mini.js"></script>
    

    

    

	
    

    

    

    
    <script src="/libs/instantpage/instantpage.js" type="module"></script>
    

<!-- Live2D mascot widget (generated config): loads the "hibiki" model, anchored right at 145x315px, shown at 50% scale on mobile, 0.7 opacity (0.8 on hover). -->
<script src="/live2dw/lib/L2Dwidget.min.js?094cbace49a39548bed64abff5988b05"></script><script>L2Dwidget.init({"pluginRootPath":"live2dw/","pluginJsPath":"lib/","pluginModelPath":"assets/","tagMode":false,"debug":false,"model":{"jsonPath":"live2d-widget-model-hibiki"},"display":{"position":"right","width":145,"height":315},"mobile":{"show":true,"scale":0.5},"react":{"opacityDefault":0.7,"opacityOnHover":0.8},"log":false});</script></body>

</html>
